package site.ruyi.menshen;

import javassist.CannotCompileException;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtMethod;
import javassist.expr.ExprEditor;
import javassist.expr.MethodCall;
import javassist.expr.NewExpr;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.IllegalClassFormatException;
import java.lang.instrument.Instrumentation;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.security.ProtectionDomain;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Java agent ("menshen") that rewrites classes at load time so that configured API calls and
 * object instantiations are guarded by permission checks.
 *
 * <p>Configuration is read through the agent's class loader:
 * <ul>
 *   <li>{@code menshen.properties}: permission rules keyed {@code <type>.<package-or-class>} or
 *       {@code <type>.default}. A blank value or {@code deny} denies; any other value allows.
 *       A missing file means no rules, i.e. everything is allowed.</li>
 *   <li>{@code <type>.permission} files in the classpath root directory: each non-blank line is
 *       either {@code call <class>.<method>} or {@code new <class>}, naming an API guarded under
 *       permission type {@code <type>}.</li>
 * </ul>
 */
public class MenshenAgent {

    /** Classpath resource holding the permission rules. */
    private static final String DEFAULT_CONFIG_FILE = "menshen.properties";

    /** Permission rules; empty when {@link #DEFAULT_CONFIG_FILE} is not on the classpath. */
    private static final Properties config = loadConfig();

    /** Permission type -> guarded method calls ({@code <class>.<method>}). */
    private static final Map<String, List<String>> CALL_PERMISSION_MAP = new HashMap<>();
    /** Permission type -> guarded instantiations ({@code <class>}). */
    private static final Map<String, List<String>> NEW_PERMISSION_MAP = new HashMap<>();
    static {
        loadPermissionMappings();
    }

    /**
     * Cache of resolved permission results keyed by config key, to avoid repeated property lookups.
     * Concurrent because {@link #checkCallPermission} runs inside instrumented application code on
     * arbitrary threads.
     */
    private static final Map<String, Boolean> PERMISSION_MAP_CACHE = new ConcurrentHashMap<>();

    /** Name prefix of this agent's own classes, which must never be instrumented. */
    private static final String AGENT_PACKAGE_PREFIX = agentPackagePrefix();

    /** Returns the class loader that loaded the agent, falling back to the system loader. */
    private static ClassLoader agentClassLoader() {
        ClassLoader loader = MenshenAgent.class.getClassLoader();
        return loader != null ? loader : ClassLoader.getSystemClassLoader();
    }

    /**
     * Loads {@link #DEFAULT_CONFIG_FILE}; a missing resource yields empty properties.
     *
     * @throws RuntimeException if the resource exists but cannot be read
     */
    private static Properties loadConfig() {
        Properties props = new Properties();
        // try-with-resources tolerates a null resource: close() is simply skipped.
        try (InputStream in = agentClassLoader().getResourceAsStream(DEFAULT_CONFIG_FILE)) {
            if (in != null) {
                props.load(in);
            }
        } catch (IOException e) {
            throw new RuntimeException("can not load menshen.properties !", e);
        }
        return props;
    }

    /**
     * Resolves the classpath root directory used to look up {@code *.permission} files.
     *
     * @return the directory, or {@code null} when the class loader reports no root resource
     */
    private static File resolveClasspathDir() {
        URL root = agentClassLoader().getResource("");
        if (root == null) {
            return null;
        }
        try {
            // toURI() decodes percent-escapes (e.g. spaces) that URL.getFile() would leave in place.
            return new File(root.toURI());
        } catch (URISyntaxException | IllegalArgumentException e) {
            // Not a plain file: URI; fall back to the raw path (listFiles() then yields null).
            return new File(root.getFile());
        }
    }

    /**
     * Populates {@link #CALL_PERMISSION_MAP} and {@link #NEW_PERMISSION_MAP} from every
     * {@code <type>.permission} file in the classpath root directory.
     *
     * @throws RuntimeException if a permission file cannot be read
     */
    private static void loadPermissionMappings() {
        final String permissionFileSuffix = ".permission";
        final String callActionFlag = "call ";
        final String newActionFlag = "new ";

        File classpathDir = resolveClasspathDir();
        if (classpathDir == null) {
            return;
        }
        File[] permissionFiles = classpathDir.listFiles((dir, name) -> name.endsWith(permissionFileSuffix));
        if (permissionFiles == null) {
            return;
        }
        for (File permissionFile : permissionFiles) {
            String fileName = permissionFile.getName();
            String permissionType = fileName.substring(0, fileName.length() - permissionFileSuffix.length());
            List<String> lines;
            try {
                lines = Files.readAllLines(permissionFile.toPath());
            } catch (IOException e) {
                throw new RuntimeException("can not load permission file -> " + permissionFile.getAbsolutePath(), e);
            }
            for (String rawLine : lines) {
                String line = rawLine.trim();
                if (line.isEmpty()) {
                    continue;
                }
                if (line.startsWith(callActionFlag)) {
                    CALL_PERMISSION_MAP.computeIfAbsent(permissionType, k -> new ArrayList<>())
                            .add(line.substring(callActionFlag.length()));
                } else if (line.startsWith(newActionFlag)) {
                    NEW_PERMISSION_MAP.computeIfAbsent(permissionType, k -> new ArrayList<>())
                            .add(line.substring(newActionFlag.length()));
                }
                // Lines with any other prefix are ignored.
            }
        }
    }

    /**
     * Decides whether {@code callContext} (a class or package name) holds {@code permissionType}.
     *
     * <p>Looks up {@code <type>.<callContext>}; if absent, retries with the last dotted segment
     * stripped, then {@code <type>.default}. If no key matches at all, permission is granted.
     *
     * @param callContext class or package name, or {@code null} to consult only the default rule
     * @param permissionType permission type (prefix of the config key)
     * @return {@code false} if the first matching rule is blank or {@code deny}, else {@code true}
     */
    private static boolean hasPermission(String callContext, String permissionType) {
        String configKey = callContext == null
                ? permissionType + ".default"
                : permissionType + "." + callContext;

        Boolean cached = PERMISSION_MAP_CACHE.get(configKey);
        if (cached != null) {
            return cached;
        }

        if (!config.containsKey(configKey)) {
            if (callContext == null) {
                return true;
            }
            int lastDot = callContext.lastIndexOf('.');
            return lastDot == -1
                    ? hasPermission(null, permissionType)
                    : hasPermission(callContext.substring(0, lastDot), permissionType);
        }

        String permission = config.getProperty(configKey);
        // A blank value is treated the same as an explicit "deny".
        boolean allowed = permission != null
                && !permission.trim().isEmpty()
                && !"deny".equals(permission.trim());
        PERMISSION_MAP_CACHE.put(configKey, allowed);
        return allowed;
    }

    /**
     * Runtime check inserted into instrumented code: every class on the current thread's stack
     * trace (as returned by {@link Thread#getStackTrace()}) must hold {@code permissionType}.
     *
     * @param permissionType permission type to verify
     * @return {@code true} only if no stack frame's class is denied
     */
    public static boolean checkCallPermission(String permissionType) {
        for (StackTraceElement frame : Thread.currentThread().getStackTrace()) {
            if (!hasPermission(frame.getClassName(), permissionType)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Guards a method call site if it targets one of {@code coreApiList}.
     *
     * <p>If the target class itself is denied for {@code permissionType}, the call is replaced by
     * an unconditional throw. Otherwise a runtime {@link #checkCallPermission} guard is inserted
     * before the original call.
     *
     * @return {@code true} if the call site was rewritten
     */
    private static boolean buildCoreApiCheck(MethodCall m, String permissionType, List<String> coreApiList)
            throws CannotCompileException {
        if (permissionType == null || permissionType.isEmpty()) {
            return false;
        }
        if (coreApiList == null || coreApiList.isEmpty()) {
            return false;
        }

        String callMethod = m.getClassName() + "." + m.getMethodName();
        if (!coreApiList.contains(callMethod)) {
            return false;
        }
        if (!hasPermission(m.getClassName(), permissionType)) {
            m.replace(
                    "{throw new RuntimeException(\"" + callMethod + " call deny by menshen!\");$_ = $proceed($$);}"
            );
        } else {
            m.replace(
                    "{if(!" + MenshenAgent.class.getCanonicalName() + ".checkCallPermission(\"" + permissionType + "\"))"
                            + "{throw new RuntimeException(\"" + callMethod + " call deny by menshen!\");}$_ = $proceed($$);}");
        }
        return true;
    }

    /**
     * Guards an instantiation site if the instantiated class is in {@code coreApiList}.
     *
     * <p>Same strategy as the {@link MethodCall} overload: an unconditional throw if the class
     * itself is denied, otherwise a runtime {@link #checkCallPermission} guard.
     *
     * @return {@code true} if the expression was rewritten
     */
    private static boolean buildCoreApiCheck(NewExpr e, String permissionType, List<String> coreApiList)
            throws CannotCompileException {
        if (permissionType == null || permissionType.isEmpty()) {
            return false;
        }
        if (coreApiList == null || coreApiList.isEmpty()) {
            return false;
        }

        String instanceClass = e.getClassName();
        if (!coreApiList.contains(instanceClass)) {
            return false;
        }
        if (!hasPermission(instanceClass, permissionType)) {
            e.replace(
                    "{throw new RuntimeException(\"" + instanceClass + " instance deny by menshen!\");$_ = $proceed($$);}"
            );
        } else {
            e.replace(
                    "{if(!" + MenshenAgent.class.getCanonicalName() + ".checkCallPermission(\"" + permissionType + "\"))"
                            + "{throw new RuntimeException(\"" + instanceClass + " instance deny by menshen!\");}$_ = $proceed($$);}");
        }
        return true;
    }

    /** Computes the name prefix shared by this agent's classes (its package, or its own name). */
    private static String agentPackagePrefix() {
        String name = MenshenAgent.class.getName();
        int lastDot = name.lastIndexOf('.');
        return lastDot < 0 ? name : name.substring(0, lastDot + 1);
    }

    /**
     * Returns whether a class must be left untouched. This covers JDK API packages, Javassist
     * itself (it performs the rewriting) and this agent's own classes.
     */
    private static boolean isExcluded(String javaClassName) {
        return javaClassName.startsWith("java.")
                || javaClassName.startsWith("javax.")
                || javaClassName.startsWith("sun.")
                || javaClassName.startsWith("com.sun.")
                || javaClassName.startsWith("jdk.")
                || javaClassName.startsWith("javassist.")
                || javaClassName.startsWith(AGENT_PACKAGE_PREFIX);
    }

    /**
     * Rewrites one class file, guarding every configured call/instantiation site.
     *
     * @param classfileBuffer the bytes handed to the transformer
     * @return the rewritten bytes, or {@code null} for annotations and interfaces (left unchanged)
     */
    private static byte[] instrumentClass(byte[] classfileBuffer) throws IOException, CannotCompileException {
        ClassPool classPool = ClassPool.getDefault();
        // Parse the bytes actually being loaded instead of looking the class up on the pool's class
        // path, so classes defined by non-system class loaders are instrumented as well.
        CtClass ctClass = classPool.makeClass(new ByteArrayInputStream(classfileBuffer));
        try {
            if (ctClass.isAnnotation() || ctClass.isInterface()) {
                return null;
            }
            ExprEditor editor = new ExprEditor() {
                @Override
                public void edit(MethodCall m) throws CannotCompileException {
                    for (Map.Entry<String, List<String>> entry : CALL_PERMISSION_MAP.entrySet()) {
                        if (buildCoreApiCheck(m, entry.getKey(), entry.getValue())) {
                            return;
                        }
                    }
                }

                @Override
                public void edit(NewExpr e) throws CannotCompileException {
                    for (Map.Entry<String, List<String>> entry : NEW_PERMISSION_MAP.entrySet()) {
                        if (buildCoreApiCheck(e, entry.getKey(), entry.getValue())) {
                            return;
                        }
                    }
                }
            };
            // CtClass.instrument covers constructors and initializers too, not just methods, so
            // guarded APIs invoked from a constructor cannot bypass the check.
            ctClass.instrument(editor);
            byte[] bytecode = ctClass.toBytecode();
            //ctClass.debugWriteFile("debug-output"); // dump rewritten class for debugging
            return bytecode;
        } finally {
            // Always drop the class from the shared pool, even on failure, so entries do not pile up.
            ctClass.detach();
        }
    }

    /**
     * Agent entry point for {@code -javaagent}: registers a retransform-capable transformer that
     * instruments every non-excluded class as it is loaded.
     */
    public static void premain(String agentArgs, Instrumentation inst) {
        try {
            inst.addTransformer(new ClassFileTransformer() {

                @Override
                public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined,
                                        ProtectionDomain protectionDomain, byte[] classfileBuffer)
                        throws IllegalClassFormatException {
                    // Some JVMs pass a null name for runtime-generated classes; nothing to match on.
                    if (className == null) {
                        return null;
                    }
                    String javaClassName = className.replace('/', '.');
                    if (isExcluded(javaClassName)) {
                        return null;
                    }
                    try {
                        return instrumentClass(classfileBuffer);
                    } catch (Exception e) {
                        // Returning null loads the class unmodified.
                        // NOTE(review): this fails open — the class runs without permission guards.
                        e.printStackTrace();
                        return null;
                    }
                }
            }, true);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Dynamic-attach entry point; currently a no-op (nothing is instrumented when attached late). */
    public static void agentmain(String agentArgs, Instrumentation inst) {

    }
}
